{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "41fb78a4-5aa1-4288-9cc2-6f742062f0a3",
      "metadata": {
        "id": "41fb78a4-5aa1-4288-9cc2-6f742062f0a3"
      },
      "source": [
        "# Fine Tuning with OpenAI"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "f8d0713f-0f79-460f-8acb-47afb877d24a",
      "metadata": {
        "jp-MarkdownHeadingCollapsed": true,
        "id": "f8d0713f-0f79-460f-8acb-47afb877d24a"
      },
      "source": [
        "## Utilities"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2cdfe762-3200-4459-981e-0ded7c14b4de",
      "metadata": {
        "id": "2cdfe762-3200-4459-981e-0ded7c14b4de"
      },
      "outputs": [],
      "source": [
        "# Constants - used for printing to stdout in color\n",
        "\n",
        "# ANSI escape sequences; RESET restores the terminal's default styling\n",
        "GREEN = \"\\033[92m\"\n",
        "YELLOW = \"\\033[93m\"\n",
        "RED = \"\\033[91m\"\n",
        "RESET = \"\\033[0m\"\n",
        "# Maps the color names produced by Tester.color_for to ANSI codes\n",
        "COLOR_MAP = {\"red\":RED, \"orange\": YELLOW, \"green\": GREEN}"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "d9f325d5-fb67-475c-aca0-01c0f0ea5ec1",
      "metadata": {
        "jp-MarkdownHeadingCollapsed": true,
        "id": "d9f325d5-fb67-475c-aca0-01c0f0ea5ec1"
      },
      "source": [
        "### Item"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0832e74b-2779-4822-8e6c-4361ec165c7f",
      "metadata": {
        "id": "0832e74b-2779-4822-8e6c-4361ec165c7f"
      },
      "outputs": [],
      "source": [
        "from typing import Optional\n",
        "from transformers import AutoTokenizer\n",
        "import re\n",
        "\n",
        "BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n",
        "\n",
        "MIN_TOKENS = 150 # Any less than this, and we don't have enough useful content\n",
        "MAX_TOKENS = 160 # Truncate after this many tokens. Then after adding in prompt text, we will get to around 180 tokens\n",
        "\n",
        "MIN_CHARS = 300\n",
        "CEILING_CHARS = MAX_TOKENS * 7 # Character cap before tokenizing (assumes roughly 7 chars per token)\n",
        "\n",
        "class Item:\n",
        "    \"\"\"\n",
        "    An Item is a cleaned, curated datapoint of a Product with a Price\n",
        "    \"\"\"\n",
        "\n",
        "    # Shared tokenizer used to measure and truncate prompt text\n",
        "    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
        "    PREFIX = \"Price is $\"\n",
        "    QUESTION = \"How much does this cost to the nearest dollar?\"\n",
        "    # Boilerplate fragments stripped from the details text by scrub_details()\n",
        "    REMOVALS = ['\"Batteries Included?\": \"No\"', '\"Batteries Included?\": \"Yes\"', '\"Batteries Required?\": \"No\"', '\"Batteries Required?\": \"Yes\"', \"By Manufacturer\", \"Item\", \"Date First\", \"Package\", \":\", \"Number of\", \"Best Sellers\", \"Number\", \"Product \"]\n",
        "\n",
        "    title: str\n",
        "    price: float\n",
        "    category: str\n",
        "    token_count: int = 0\n",
        "    details: Optional[str]\n",
        "    prompt: Optional[str] = None\n",
        "    # Set to True by parse() only when the datapoint passes the token threshold\n",
        "    include = False\n",
        "\n",
        "    def __init__(self, data, price):\n",
        "        \"\"\"Capture title and price, then parse/curate the rest of the datapoint.\"\"\"\n",
        "        self.title = data['title']\n",
        "        self.price = price\n",
        "        self.parse(data)\n",
        "\n",
        "    def scrub_details(self):\n",
        "        \"\"\"\n",
        "        Clean up the details string by removing common text that doesn't add value\n",
        "        \"\"\"\n",
        "        details = self.details\n",
        "        for remove in self.REMOVALS:\n",
        "            details = details.replace(remove, \"\")\n",
        "        return details\n",
        "\n",
        "    def scrub(self, stuff):\n",
        "        \"\"\"\n",
        "        Clean up the provided text by removing unnecessary characters and whitespace\n",
        "        Also remove words that are 7+ chars and contain numbers, as these are likely irrelevant product numbers\n",
        "        \"\"\"\n",
        "        # Collapse punctuation/brackets (incl. CJK brackets) and whitespace runs to single spaces\n",
        "        stuff = re.sub(r'[:\\[\\]\"{}【】\\s]+', ' ', stuff).strip()\n",
        "        stuff = stuff.replace(\" ,\", \",\").replace(\",,,\",\",\").replace(\",,\",\",\")\n",
        "        words = stuff.split(' ')\n",
        "        # Drop long tokens containing digits (likely product/part numbers with no pricing signal)\n",
        "        select = [word for word in words if len(word)<7 or not any(char.isdigit() for char in word)]\n",
        "        return \" \".join(select)\n",
        "\n",
        "    def parse(self, data):\n",
        "        \"\"\"\n",
        "        Parse this datapoint and if it fits within the allowed Token range,\n",
        "        then set include to True\n",
        "        \"\"\"\n",
        "        contents = '\\n'.join(data['description'])\n",
        "        if contents:\n",
        "            contents += '\\n'\n",
        "        features = '\\n'.join(data['features'])\n",
        "        if features:\n",
        "            contents += features + '\\n'\n",
        "        self.details = data['details']\n",
        "        if self.details:\n",
        "            contents += self.scrub_details() + '\\n'\n",
        "        if len(contents) > MIN_CHARS:\n",
        "            # Cheap character-level truncation before the (slower) tokenizer pass\n",
        "            contents = contents[:CEILING_CHARS]\n",
        "            text = f\"{self.scrub(self.title)}\\n{self.scrub(contents)}\"\n",
        "            tokens = self.tokenizer.encode(text, add_special_tokens=False)\n",
        "            if len(tokens) > MIN_TOKENS:\n",
        "                tokens = tokens[:MAX_TOKENS]\n",
        "                text = self.tokenizer.decode(tokens)\n",
        "                self.make_prompt(text)\n",
        "                self.include = True\n",
        "\n",
        "    def make_prompt(self, text):\n",
        "        \"\"\"\n",
        "        Set the prompt instance variable to be a prompt appropriate for training\n",
        "        \"\"\"\n",
        "        self.prompt = f\"{self.QUESTION}\\n\\n{text}\\n\\n\"\n",
        "        # round() gives a whole-dollar label; \".00\" keeps the price format uniform\n",
        "        self.prompt += f\"{self.PREFIX}{str(round(self.price))}.00\"\n",
        "        self.token_count = len(self.tokenizer.encode(self.prompt, add_special_tokens=False))\n",
        "\n",
        "    def test_prompt(self):\n",
        "        \"\"\"\n",
        "        Return a prompt suitable for testing, with the actual price removed\n",
        "        \"\"\"\n",
        "        return self.prompt.split(self.PREFIX)[0] + self.PREFIX\n",
        "\n",
        "    def __repr__(self):\n",
        "        \"\"\"\n",
        "        Return a String version of this Item\n",
        "        \"\"\"\n",
        "        return f\"<{self.title} = ${self.price}>\"\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Tester"
      ],
      "metadata": {
        "id": "LaIwYGzItsEi"
      },
      "id": "LaIwYGzItsEi"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "129470d7-a5b1-4851-8800-970cccc8bcf5",
      "metadata": {
        "id": "129470d7-a5b1-4851-8800-970cccc8bcf5"
      },
      "outputs": [],
      "source": [
        "class Tester:\n",
        "    \"\"\"\n",
        "    Runs a predictor over the first `size` datapoints, printing a color-coded\n",
        "    result line per item and finishing with a scatter chart and summary stats.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, predictor, data, title=None, size=250):\n",
        "        self.predictor = predictor\n",
        "        self.data = data\n",
        "        self.title = title or predictor.__name__.replace(\"_\", \" \").title()\n",
        "        self.size = size\n",
        "        self.guesses = []\n",
        "        self.truths = []\n",
        "        self.errors = []\n",
        "        self.sles = []\n",
        "        self.colors = []\n",
        "\n",
        "    def color_for(self, error, truth):\n",
        "        \"\"\"Bucket a prediction as green/orange/red by absolute and relative error.\"\"\"\n",
        "        if error<40 or error/truth < 0.2:\n",
        "            return \"green\"\n",
        "        elif error<80 or error/truth < 0.4:\n",
        "            return \"orange\"\n",
        "        else:\n",
        "            return \"red\"\n",
        "\n",
        "    def run_datapoint(self, i):\n",
        "        \"\"\"Predict datapoint i, record its error metrics, and print a colored summary.\"\"\"\n",
        "        datapoint = self.data[i]\n",
        "        guess = self.predictor(datapoint)\n",
        "        truth = datapoint.price\n",
        "        error = abs(guess - truth)\n",
        "        # Squared log error is robust across the wide range of prices\n",
        "        log_error = math.log(truth+1) - math.log(guess+1)\n",
        "        sle = log_error ** 2\n",
        "        color = self.color_for(error, truth)\n",
        "        title = datapoint.title if len(datapoint.title) <= 40 else datapoint.title[:40]+\"...\"\n",
        "        self.guesses.append(guess)\n",
        "        self.truths.append(truth)\n",
        "        self.errors.append(error)\n",
        "        self.sles.append(sle)\n",
        "        self.colors.append(color)\n",
        "        print(f\"{COLOR_MAP[color]}{i+1}: Guess: ${guess:,.2f} Truth: ${truth:,.2f} Error: ${error:,.2f} SLE: {sle:,.2f} Item: {title}{RESET}\")\n",
        "\n",
        "    def chart(self, title):\n",
        "        \"\"\"Scatter guesses against ground truth with a y=x reference line.\"\"\"\n",
        "        plt.figure(figsize=(12, 8))\n",
        "        max_val = max(max(self.truths), max(self.guesses))\n",
        "        plt.plot([0, max_val], [0, max_val], color='deepskyblue', lw=2, alpha=0.6)\n",
        "        plt.scatter(self.truths, self.guesses, s=3, c=self.colors)\n",
        "        plt.xlabel('Ground Truth')\n",
        "        plt.ylabel('Model Estimate')\n",
        "        plt.xlim(0, max_val)\n",
        "        plt.ylim(0, max_val)\n",
        "        plt.title(title)\n",
        "        plt.show()\n",
        "\n",
        "    def report(self):\n",
        "        \"\"\"Summarize average error, RMSLE and hit-rate, then draw the chart.\"\"\"\n",
        "        average_error = sum(self.errors) / self.size\n",
        "        rmsle = math.sqrt(sum(self.sles) / self.size)\n",
        "        hits = sum(1 for color in self.colors if color==\"green\")\n",
        "        title = f\"{self.title} Error=${average_error:,.2f} RMSLE={rmsle:,.2f} Hits={hits/self.size*100:.1f}%\"\n",
        "        self.chart(title)\n",
        "\n",
        "    def run(self):\n",
        "        \"\"\"Evaluate the first `size` datapoints, then report.\"\"\"\n",
        "        for i in range(self.size):\n",
        "            self.run_datapoint(i)\n",
        "        self.report()\n",
        "\n",
        "    @classmethod\n",
        "    def test(cls, function, data):\n",
        "        \"\"\"Convenience entry point: construct a Tester and run it immediately.\"\"\"\n",
        "        cls(function, data).run()"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# A utility function to extract the price from a string\n",
        "\n",
        "def get_price(s):\n",
        "    \"\"\"Return the first number found in s (after stripping '$' and ','), or 0 if none.\"\"\"\n",
        "    cleaned = s.replace('$','').replace(',','')\n",
        "    match = re.search(r'[-+]?\\d*\\.?\\d+', cleaned)\n",
        "    if match:\n",
        "        return float(match.group())\n",
        "    return 0"
      ],
      "metadata": {
        "id": "6XywRUiUro69"
      },
      "id": "6XywRUiUro69",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "id": "10af1228-30b7-4dfc-a364-059ea099af81",
      "metadata": {
        "id": "10af1228-30b7-4dfc-a364-059ea099af81"
      },
      "source": [
        "## Data Curation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5faa087c-bdf7-42e5-9c32-c0b0a4d4160f",
      "metadata": {
        "id": "5faa087c-bdf7-42e5-9c32-c0b0a4d4160f"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade --quiet  jupyterlab ipython ipywidgets huggingface_hub datasets transformers\n",
        "\n",
        "# NOTE(review): '%matplotlib notebook' needs the classic-notebook backend and is\n",
        "# unavailable in JupyterLab/Colab; a later cell switches to '%matplotlib inline' anyway\n",
        "%matplotlib notebook\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Load Dataset from Hugging Face"
      ],
      "metadata": {
        "id": "3XTxVhq0xC8Z"
      },
      "id": "3XTxVhq0xC8Z"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2bd6fc25-77c4-47a6-a2d2-ce80403f3c22",
      "metadata": {
        "id": "2bd6fc25-77c4-47a6-a2d2-ce80403f3c22"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset, Dataset, DatasetDict\n",
        "from transformers import AutoTokenizer\n",
        "\n",
        "\n",
        "# Amazon Reviews 2023 'All Beauty' product metadata, hosted on Hugging Face\n",
        "dataset = load_dataset('ranskills/Amazon-Reviews-2023-raw_meta_All_Beauty', split='full')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b66b59c2-80b2-4d47-b739-c59423cf9d7d",
      "metadata": {
        "id": "b66b59c2-80b2-4d47-b739-c59423cf9d7d"
      },
      "outputs": [],
      "source": [
        "from IPython.display import display, JSON\n",
        "\n",
        "\n",
        "print(f'Number of datapoints: {dataset.num_rows:,}')\n",
        "display(JSON(dataset.features.to_dict()))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e9620ed3-205e-48ee-b67a-e56b30bf6b6b",
      "metadata": {
        "id": "e9620ed3-205e-48ee-b67a-e56b30bf6b6b"
      },
      "outputs": [],
      "source": [
        "def non_zero_price_filter(datapoint: dict) -> bool:\n",
        "    \"\"\"Keep only datapoints whose 'price' field parses to a positive float.\"\"\"\n",
        "    try:\n",
        "        price = float(datapoint['price'])\n",
        "        return price > 0\n",
        "    except (KeyError, TypeError, ValueError):\n",
        "        # Missing or unparseable price: exclude the datapoint\n",
        "        return False\n",
        "\n",
        "filtered_dataset = dataset.filter(non_zero_price_filter)\n",
        "\n",
        "print(f'Datapoints with non-zero prices: {filtered_dataset.num_rows:,} = {filtered_dataset.num_rows / dataset.num_rows * 100:,.2f}%')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "834a3c4b-fc9c-4bc7-b6b9-bdf7e8d6d585",
      "metadata": {
        "id": "834a3c4b-fc9c-4bc7-b6b9-bdf7e8d6d585"
      },
      "outputs": [],
      "source": [
        "from collections import defaultdict\n",
        "\n",
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "%matplotlib inline\n",
        "\n",
        "# Collect price and raw content length per datapoint for distribution analysis\n",
        "data = defaultdict(list)\n",
        "for datapoint in filtered_dataset:\n",
        "    price = float(datapoint['price'])\n",
        "    contents = datapoint[\"title\"] + str(datapoint[\"description\"]) + str(datapoint[\"features\"]) + str(datapoint[\"details\"])\n",
        "\n",
        "    data['price'].append(price)\n",
        "    data['characters'].append(len(contents))\n",
        "\n",
        "df = pd.DataFrame(data)\n",
        "\n",
        "# Side-by-side summary statistics for both distributions\n",
        "combined_describe = pd.concat(\n",
        "    [df['price'].describe(), df['characters'].describe()],\n",
        "    axis=1\n",
        ")\n",
        "\n",
        "display(combined_describe)\n",
        "\n",
        "prices = data['price']\n",
        "lengths = data['characters']\n",
        "\n",
        "plt.figure(figsize=(15, 6))\n",
        "plt.title(f\"Prices: Avg {df['price'].mean():,.2f} and highest {df['price'].max():,}\\n\")\n",
        "plt.xlabel('Price ($)')\n",
        "plt.ylabel('Count')\n",
        "plt.hist(prices, rwidth=0.7, color=\"orange\", bins=range(0, 300, 10))\n",
        "plt.show()\n",
        "\n",
        "plt.figure(figsize=(15, 6))\n",
        "plt.title(f\"Characters: Avg {sum(lengths)/len(lengths):,.0f} and highest {max(lengths):,}\\n\")\n",
        "plt.xlabel('Length (characters)')\n",
        "plt.ylabel('Count')\n",
        "plt.hist(lengths, rwidth=0.7, color=\"lightblue\", bins=range(0, 2500, 50))\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a506f42c-81c0-4198-bc0b-1e0653620be8",
      "metadata": {
        "id": "a506f42c-81c0-4198-bc0b-1e0653620be8"
      },
      "outputs": [],
      "source": [
        "BASE_MODEL = 'meta-llama/Meta-Llama-3.1-8B'\n",
        "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)\n",
        "\n",
        "# Build a curated Item for every datapoint with a valid price\n",
        "items = []\n",
        "for datapoint in filtered_dataset:\n",
        "    price = float(datapoint['price'])\n",
        "    items.append(Item(datapoint, price))\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5842ace6-332d-46da-a853-5ea5a2a1cf88",
      "metadata": {
        "id": "5842ace6-332d-46da-a853-5ea5a2a1cf88"
      },
      "outputs": [],
      "source": [
        "print(items[0].test_prompt())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "42ee0099-0d2a-4331-a01c-3462363a6987",
      "metadata": {
        "id": "42ee0099-0d2a-4331-a01c-3462363a6987"
      },
      "outputs": [],
      "source": [
        "# Items whose content fell below the minimum threshold never received a prompt; drop them\n",
        "valid_items = [item for item in items if item.prompt is not None]\n",
        "\n",
        "data_size = len(valid_items)\n",
        "\n",
        "# Deterministic 90/10 split, preserving dataset order\n",
        "training_size = int(data_size * 0.9)\n",
        "train, test = valid_items[:training_size], valid_items[training_size:]\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1146d5a2-f93e-4fe9-864e-4ce7e01e257b",
      "metadata": {
        "id": "1146d5a2-f93e-4fe9-864e-4ce7e01e257b"
      },
      "outputs": [],
      "source": [
        "# Training examples keep the full prompt (including the price);\n",
        "# test examples stop at the price prefix so the model must complete it\n",
        "train_prompts = [item.prompt for item in train]\n",
        "train_prices = [item.price for item in train]\n",
        "test_prompts = [item.test_prompt() for item in test]\n",
        "test_prices = [item.price for item in test]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "31ca360d-5fc6-487a-91c6-d61758b2ff16",
      "metadata": {
        "id": "31ca360d-5fc6-487a-91c6-d61758b2ff16"
      },
      "outputs": [],
      "source": [
        "# Create a Dataset from the lists\n",
        "\n",
        "train_dataset = Dataset.from_dict({\"text\": train_prompts, \"price\": train_prices})\n",
        "test_dataset = Dataset.from_dict({\"text\": test_prompts, \"price\": test_prices})\n",
        "dataset = DatasetDict({\n",
        "    \"train\": train_dataset,\n",
        "    \"test\": test_dataset\n",
        "})"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "05e6ca7e-bf40-49f9-bffb-a5b22e5800d8",
      "metadata": {
        "id": "05e6ca7e-bf40-49f9-bffb-a5b22e5800d8"
      },
      "source": [
        "### Export Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b0ff2fe3-78bf-49e3-a682-6a46742d010c",
      "metadata": {
        "id": "b0ff2fe3-78bf-49e3-a682-6a46742d010c"
      },
      "outputs": [],
      "source": [
        "import pickle\n",
        "from pathlib import Path\n",
        "\n",
        "DATA_DIR = 'data'\n",
        "\n",
        "# The output directory may not exist yet on a fresh runtime - create it first\n",
        "Path(DATA_DIR).mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "def train_storage_file(ext):\n",
        "    \"\"\"Path of the training-set export for the given extension (e.g. '.pkl').\"\"\"\n",
        "    return f'{DATA_DIR}/all_beauty_train{ext}'\n",
        "\n",
        "def test_storage_file(ext):\n",
        "    \"\"\"Path of the test-set export for the given extension.\"\"\"\n",
        "    return f'{DATA_DIR}/all_beauty_test{ext}'\n",
        "\n",
        "with open(train_storage_file('.pkl'), 'wb') as file:\n",
        "    pickle.dump(train, file)\n",
        "\n",
        "with open(test_storage_file('.pkl'), 'wb') as file:\n",
        "    pickle.dump(test, file)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b2164662-9bc9-4a66-9e4e-a8a955a45753",
      "metadata": {
        "id": "b2164662-9bc9-4a66-9e4e-a8a955a45753"
      },
      "outputs": [],
      "source": [
        "dataset['train'].to_parquet(train_storage_file('.parquet'))\n",
        "dataset['test'].to_parquet(test_storage_file('.parquet'))\n",
        "\n",
        "# How to load back the data\n",
        "# loaded_dataset = load_dataset(\"parquet\", data_files=train_storage_file('.parquet'))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "6fe428a2-41c4-4f7f-a43f-e8ba2f344013",
      "metadata": {
        "id": "6fe428a2-41c4-4f7f-a43f-e8ba2f344013"
      },
      "source": [
        "### Predictions"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Random Pricer"
      ],
      "metadata": {
        "id": "qX0c_prppnyZ"
      },
      "id": "qX0c_prppnyZ"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7323252b-db50-4b8a-a7fc-8504bb3d218b",
      "metadata": {
        "id": "7323252b-db50-4b8a-a7fc-8504bb3d218b"
      },
      "outputs": [],
      "source": [
        "import random\n",
        "import math\n",
        "\n",
        "\n",
        "def random_pricer(item):\n",
        "    \"\"\"Baseline: a uniformly random whole-dollar guess between $1 and $199.\"\"\"\n",
        "    return random.randrange(1,200)\n",
        "\n",
        "# Seed so the baseline run is reproducible\n",
        "random.seed(42)\n",
        "\n",
        "# Run our TestRunner\n",
        "Tester.test(random_pricer, test)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "#### Constant Pricer"
      ],
      "metadata": {
        "id": "O0xVXRXkp9sQ"
      },
      "id": "O0xVXRXkp9sQ"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6a932b0e-ba6e-45d2-8436-b740c3681272",
      "metadata": {
        "id": "6a932b0e-ba6e-45d2-8436-b740c3681272"
      },
      "outputs": [],
      "source": [
        "# Baseline 2: always predict the training-set mean price\n",
        "training_prices = [item.price for item in train]\n",
        "training_average = sum(training_prices) / len(training_prices)\n",
        "\n",
        "def constant_pricer(item):\n",
        "    \"\"\"Ignore the item and return the average training price.\"\"\"\n",
        "    return training_average\n",
        "\n",
        "Tester.test(constant_pricer, test)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d3410bd4-98e4-42a6-a702-4423cfd034b4",
      "metadata": {
        "id": "d3410bd4-98e4-42a6-a702-4423cfd034b4"
      },
      "outputs": [],
      "source": [
        "train[0].details"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "44537051-7b4e-4b8c-95a7-a989ea51e517",
      "metadata": {
        "id": "44537051-7b4e-4b8c-95a7-a989ea51e517"
      },
      "source": [
        "### Prepare Fine-Tuning Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "47d03b0b-4a93-4f9d-80ac-10f3fc11ccec",
      "metadata": {
        "id": "47d03b0b-4a93-4f9d-80ac-10f3fc11ccec"
      },
      "outputs": [],
      "source": [
        "# Small slices of the training data for fine-tuning:\n",
        "# 100 examples to train on, the next 25 for validation\n",
        "fine_tune_train = train[:100]\n",
        "fine_tune_validation = train[100:125]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4d7b6f35-890c-4227-8990-6b62694a332d",
      "metadata": {
        "id": "4d7b6f35-890c-4227-8990-6b62694a332d"
      },
      "outputs": [],
      "source": [
        "def messages_for(item):\n",
        "    \"\"\"\n",
        "    Build a chat-format fine-tuning example: system instruction, the item\n",
        "    description as the user turn, and the true price as the assistant turn.\n",
        "    \"\"\"\n",
        "    system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
        "    # Strip the question suffix and the trailing price prefix from the test prompt\n",
        "    user_prompt = item.test_prompt().replace(\" to the nearest dollar\",\"\").replace(\"\\n\\nPrice is $\",\"\")\n",
        "    return [\n",
        "        {\"role\": \"system\", \"content\": system_message},\n",
        "        {\"role\": \"user\", \"content\": user_prompt},\n",
        "        {\"role\": \"assistant\", \"content\": f\"Price is ${item.price:.2f}\"}\n",
        "    ]\n",
        "\n",
        "messages_for(train[0])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "1a6e06f3-614f-4687-bd43-9ac03aaface8",
      "metadata": {
        "id": "1a6e06f3-614f-4687-bd43-9ac03aaface8"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "from pathlib import Path\n",
        "DATA_DIR = 'data'\n",
        "\n",
        "data_path = Path(DATA_DIR)\n",
        "\n",
        "def make_jsonl(items):\n",
        "    \"\"\"\n",
        "    Render items as JSONL for OpenAI fine-tuning:\n",
        "    one {\"messages\": [...]} object per line.\n",
        "    \"\"\"\n",
        "    lines = [json.dumps({\"messages\": messages_for(item)}) for item in items]\n",
        "    return \"\\n\".join(lines)\n",
        "\n",
        "# print(make_jsonl(train[:3]))\n",
        "\n",
        "# Ensure the output directory exists before the write cells run\n",
        "data_path.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "train_jsonl_path = f'{data_path}/pricer_train.jsonl'\n",
        "validation_jsonl_path = f'{data_path}/pricer_validation.jsonl'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d8dda552-8003-4fdc-b36a-7d0afa9b0b42",
      "metadata": {
        "id": "d8dda552-8003-4fdc-b36a-7d0afa9b0b42"
      },
      "outputs": [],
      "source": [
        "def write_jsonl(items, filename):\n",
        "    \"\"\"Serialize items with make_jsonl and write the result to filename.\"\"\"\n",
        "    with open(filename, \"w\") as f:\n",
        "        f.write(make_jsonl(items))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "189e959c-d70c-4509-bff6-1cbd8e8db637",
      "metadata": {
        "id": "189e959c-d70c-4509-bff6-1cbd8e8db637"
      },
      "outputs": [],
      "source": [
        "\n",
        "write_jsonl(fine_tune_train, train_jsonl_path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "6b1480e2-ed19-4d0e-bc5d-a00086d104a2",
      "metadata": {
        "id": "6b1480e2-ed19-4d0e-bc5d-a00086d104a2"
      },
      "outputs": [],
      "source": [
        "write_jsonl(fine_tune_validation, validation_jsonl_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Training"
      ],
      "metadata": {
        "id": "ga-f4JK7sPU2"
      },
      "id": "ga-f4JK7sPU2"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "de958a51-69ba-420c-84b7-d32765898fd2",
      "metadata": {
        "id": "de958a51-69ba-420c-84b7-d32765898fd2"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from openai import OpenAI\n",
        "from dotenv import load_dotenv\n",
        "from google.colab import userdata\n",
        "\n",
        "load_dotenv()\n",
        "# Pull the API key from Colab's secret store rather than hardcoding it\n",
        "os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')\n",
        "\n",
        "openai = OpenAI()"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "with open(train_jsonl_path, 'rb') as f:\n",
        "    train_file = openai.files.create(file=f, purpose='fine-tune')"
      ],
      "metadata": {
        "id": "QFDAoNnoRCk1"
      },
      "id": "QFDAoNnoRCk1",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "train_file"
      ],
      "metadata": {
        "id": "kBVWisusQwDq"
      },
      "id": "kBVWisusQwDq",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "with open(validation_jsonl_path, 'rb') as f:\n",
        "    validation_file = openai.files.create(file=f, purpose='fine-tune')\n",
        "\n",
        "validation_file"
      ],
      "metadata": {
        "id": "wgth1KvMSEOb"
      },
      "id": "wgth1KvMSEOb",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "wandb_integration = {\"type\": \"wandb\", \"wandb\": {\"project\": \"gpt-pricer\"}}"
      ],
      "metadata": {
        "id": "-ohEia37Sjtx"
      },
      "id": "-ohEia37Sjtx",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Launch the fine-tuning job (single epoch; fixed seed; W&B logging enabled)\n",
        "openai.fine_tuning.jobs.create(\n",
        "    training_file=train_file.id,\n",
        "    validation_file=validation_file.id,\n",
        "    model=\"gpt-4o-mini-2024-07-18\",\n",
        "    seed=42,\n",
        "    hyperparameters={\"n_epochs\": 1},\n",
        "    integrations = [wandb_integration],\n",
        "    suffix=\"pricer\"\n",
        ")"
      ],
      "metadata": {
        "id": "g7uz8SC5S3_s"
      },
      "id": "g7uz8SC5S3_s",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "openai.fine_tuning.jobs.list(limit=1)"
      ],
      "metadata": {
        "id": "_zHswJwzWCHZ"
      },
      "id": "_zHswJwzWCHZ",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "job_id = openai.fine_tuning.jobs.list(limit=1).data[0].id\n",
        "job_id"
      ],
      "metadata": {
        "id": "rSHYkQojWH8Q"
      },
      "id": "rSHYkQojWH8Q",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "openai.fine_tuning.jobs.retrieve(job_id)"
      ],
      "metadata": {
        "id": "Yqq-jd1yWMuO"
      },
      "id": "Yqq-jd1yWMuO",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "openai.fine_tuning.jobs.list_events(fine_tuning_job_id=job_id, limit=10).data"
      ],
      "metadata": {
        "id": "37BH0u-QWOiY"
      },
      "id": "37BH0u-QWOiY",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import wandb\n",
        "from wandb.integration.openai.fine_tuning import WandbLogger\n",
        "\n",
        "\n",
        "wandb.login()\n",
        "# Sync the fine-tuning job with Weights & Biases.\n",
        "WandbLogger.sync(fine_tune_job_id=job_id, project=\"gpt-pricer\")"
      ],
      "metadata": {
        "id": "2nNSE_AzWYMq"
      },
      "id": "2nNSE_AzWYMq",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "fine_tuned_model_name = openai.fine_tuning.jobs.retrieve(job_id).fine_tuned_model\n",
        "fine_tuned_model_name"
      ],
      "metadata": {
        "id": "ASiJUw-Fh8Ul"
      },
      "id": "ASiJUw-Fh8Ul",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Redefines messages_for for inference (shadows the training version above):\n",
        "# the assistant turn ends at the price prefix so the model completes the price itself\n",
        "def messages_for(item):\n",
        "    system_message = \"You estimate prices of items. Reply only with the price, no explanation\"\n",
        "    user_prompt = item.test_prompt().replace(\" to the nearest dollar\",\"\").replace(\"\\n\\nPrice is $\",\"\")\n",
        "    return [\n",
        "        {\"role\": \"system\", \"content\": system_message},\n",
        "        {\"role\": \"user\", \"content\": user_prompt},\n",
        "        {\"role\": \"assistant\", \"content\": \"Price is $\"}\n",
        "    ]"
      ],
      "metadata": {
        "id": "7jB_7gqBiH_r"
      },
      "id": "7jB_7gqBiH_r",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# The function for gpt-4o-mini\n",
        "\n",
        "def gpt_fine_tuned(item):\n",
        "    \"\"\"Query the fine-tuned model for an item and parse its reply into a float price.\"\"\"\n",
        "    response = openai.chat.completions.create(\n",
        "        model=fine_tuned_model_name,\n",
        "        messages=messages_for(item),\n",
        "        seed=42,\n",
        "        max_tokens=7  # a price reply needs only a few tokens\n",
        "    )\n",
        "    reply = response.choices[0].message.content\n",
        "    return get_price(reply)"
      ],
      "metadata": {
        "id": "BHfLSadhiVQE"
      },
      "id": "BHfLSadhiVQE",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "print(test[0].price)\n",
        "print(gpt_fine_tuned(test[0]))"
      ],
      "metadata": {
        "id": "C0CiTZ4jkjrI"
      },
      "id": "C0CiTZ4jkjrI",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "Tester.test(gpt_fine_tuned, test)"
      ],
      "metadata": {
        "id": "WInQE0ObkuBl"
      },
      "id": "WInQE0ObkuBl",
      "execution_count": null,
      "outputs": []
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "sagemaker-distribution:Python",
      "language": "python",
      "name": "conda-env-sagemaker-distribution-py"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.9"
    },
    "colab": {
      "provenance": []
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}